<!DOCTYPE html>
<html lang="en">
<head>
    <meta http-equiv="content-type" content="text/html;charset=utf-8"/>
    <meta name="viewport" content="width=device-width, initial-scale=1.0"/>
    <meta name="description" content="This is an annotated implementation/tutorial of the Masked Language Model in PyTorch."/>

    <meta name="twitter:card" content="summary"/>
    <meta name="twitter:image:src" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
    <meta name="twitter:title" content="Masked Language Model"/>
    <meta name="twitter:description" content="This is an annotated implementation/tutorial of the Masked Language Model in PyTorch."/>
    <meta name="twitter:site" content="@labmlai"/>
    <meta name="twitter:creator" content="@labmlai"/>

    <meta property="og:url" content="https://nn.labml.ai/transformers/mlm/index.html"/>
    <meta property="og:title" content="Masked Language Model"/>
    <meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
    <meta property="og:site_name" content="Masked Language Model"/>
    <meta property="og:type" content="object"/>
    <meta property="og:description" content="This is an annotated implementation/tutorial of the Masked Language Model in PyTorch."/>

    <title>Masked Language Model</title>
    <link rel="shortcut icon" href="/icon.png"/>
    <link rel="stylesheet" href="../../pylit.css?v=1">
    <link rel="canonical" href="https://nn.labml.ai/transformers/mlm/index.html"/>
    <link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.13.18/dist/katex.min.css" integrity="sha384-zTROYFVGOfTw7JV7KUu8udsvW2fx4lWOsCEDqhBreBwlHI4ioVRtmIvEThzJHGET" crossorigin="anonymous">

    <!-- Global site tag (gtag.js) - Google Analytics -->
    <script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
    <script>
        window.dataLayer = window.dataLayer || [];

        function gtag() {
            dataLayer.push(arguments);
        }

        gtag('js', new Date());

        gtag('config', 'G-4V3HC8HBLH');
    </script>
</head>
<body>
<div id='container'>
    <div id="background"></div>
    <div class='section'>
        <div class='docs'>
            <p>
                <a class="parent" href="/">home</a>
                <a class="parent" href="../index.html">transformers</a>
                <a class="parent" href="index.html">mlm</a>
            </p>
            <p>
                <a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations" target="_blank">
                    <img alt="Github"
                         src="https://img.shields.io/github/stars/labmlai/annotated_deep_learning_paper_implementations?style=social"
                         style="max-width:100%;"/></a>
                <a href="https://twitter.com/labmlai" rel="nofollow" target="_blank">
                    <img alt="Twitter"
                         src="https://img.shields.io/twitter/follow/labmlai?style=social"
                         style="max-width:100%;"/></a>
            </p>
            <p>
                <a href="https://github.com/labmlai/annotated_deep_learning_paper_implementations/tree/master/labml_nn/transformers/mlm/__init__.py" target="_blank">
                    View code on Github</a>
            </p>
        </div>
    </div>
    <div class='section' id='section-0'>
        <div class='docs doc-strings'>
            <div class='section-link'>
                <a href='#section-0'>#</a>
            </div>
            <h1>Masked Language Model (MLM)</h1>
<p>This is a <a href="https://pytorch.org">PyTorch</a> implementation of the Masked Language Model (MLM)  used to pre-train the BERT model introduced in the paper <a href="https://arxiv.org/abs/1810.04805">BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding</a>.</p>
<h2>BERT Pretraining</h2>
<p>BERT is a transformer model. The paper pre-trains it using MLM together with next sentence prediction. We have only implemented MLM here.</p>
<h3>Next sentence prediction</h3>
<p>In <em>next sentence prediction</em>, the model is given two sentences <code  class="highlight"><span></span><span class="n">A</span></code>
 and <code  class="highlight"><span></span><span class="n">B</span></code>
 and makes a binary prediction of whether <code  class="highlight"><span></span><span class="n">B</span></code>
 is the sentence that follows <code  class="highlight"><span></span><span class="n">A</span></code>
 in the actual text. The model is fed actual sentence pairs 50% of the time and random pairs the other 50% of the time. This classification is done while applying MLM. <em>We haven&#x27;t implemented this here.</em></p>
<h2>Masked LM</h2>
<p>MLM selects a percentage of tokens at random and trains the model to predict them. The paper <strong>masks 15% of the tokens</strong> by replacing them with a special <code  class="highlight"><span></span><span class="p">[</span><span class="n">MASK</span><span class="p">]</span></code>
 token.</p>
<p>The loss is computed only on the predictions for the masked tokens. This causes a problem during fine-tuning and actual usage, since there are no <code  class="highlight"><span></span><span class="p">[</span><span class="n">MASK</span><span class="p">]</span></code>
 tokens at that time, so the model might not give meaningful representations.</p>
<p>To overcome this, <strong>10% of the masked tokens are replaced with the original token</strong>, and another <strong>10% of the masked tokens are replaced with a random token</strong>. This trains the model to give representations of the actual token whether or not the input token at that position is a <code  class="highlight"><span></span><span class="p">[</span><span class="n">MASK</span><span class="p">]</span></code>
. Replacing with a random token also causes it to give a representation that has information from the context, because it has to use the context to fix randomly replaced tokens.</p>
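<p>The corruption rule described above can be sketched in plain Python for a single token sequence. This is a minimal illustration, not the tensor code on this page: the token ids <code>PAD</code> and <code>MASK</code>, the vocabulary size <code>N_TOKENS</code> and the helper name <code>corrupt</code> are assumptions made for the example, and it uses one random draw per selected token to get an exact 80/10/10 split.</p>

```python
import random
from typing import List, Tuple

PAD, MASK = 0, 1   # hypothetical special-token ids
N_TOKENS = 100     # hypothetical vocabulary size


def corrupt(tokens: List[int], rng: random.Random,
            masking_prob: float = 0.15,
            randomize_prob: float = 0.1,
            no_change_prob: float = 0.1) -> Tuple[List[int], List[int]]:
    """Return (corrupted inputs, labels) for one token sequence."""
    inputs, labels = [], []
    for tok in tokens:
        # Special tokens are never selected; most other tokens are skipped too
        if tok in (PAD, MASK) or rng.random() >= masking_prob:
            inputs.append(tok)
            labels.append(PAD)  # PAD label: position is ignored by the loss
            continue
        labels.append(tok)  # the model must recover the original token
        r = rng.random()
        if r < randomize_prob:
            inputs.append(rng.randrange(N_TOKENS))  # random replacement
        elif r < randomize_prob + no_change_prob:
            inputs.append(tok)  # keep the original token
        else:
            inputs.append(MASK)  # the common case: [MASK]
    return inputs, labels


rng = random.Random(0)
tokens = [rng.randrange(2, N_TOKENS) for _ in range(1000)]
inputs, labels = corrupt(tokens, rng)
```

<p>Positions that were not selected keep their input token and get a <code>PAD</code> label, so only the selected positions carry a training target.</p>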
<h2>Training</h2>
<p>MLMs are harder to train than autoregressive models because they have a smaller training signal; that is, only a small percentage of predictions are trained per sample.</p>
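<p>To get a rough feel for the difference in training signal, the following sketch counts the expected number of positions that contribute to the loss per sequence. The sequence length of 512 and the helper name <code>trained_predictions</code> are assumptions for illustration, and the autoregressive count assumes every position has a next-token target.</p>

```python
def trained_predictions(seq_len: int, masking_prob: float = 0.15):
    """Expected number of positions contributing to the loss per sequence."""
    mlm = seq_len * masking_prob  # only the selected positions have labels
    autoregressive = seq_len      # assumes every position has a target
    return mlm, autoregressive


mlm, ar = trained_predictions(512)  # 512 is an assumed sequence length
print(f"MLM: {mlm:.1f}, autoregressive: {ar}")
```

<p>Under these assumptions an MLM trains on roughly 15% as many predictions per sequence as an autoregressive model.</p>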
<p>Another problem is that, since the model is bidirectional, any token can see any other token. This makes &quot;credit assignment&quot; harder. Say a character-level model is trying to predict <code  class="highlight"><span></span><span class="n">home</span> <span class="o">*</span><span class="n">s</span> <span class="n">where</span> <span class="n">i</span> <span class="n">want</span> <span class="n">to</span> <span class="n">be</span></code>
. At least during the early stages of training, it will be very hard to figure out why the replacement for <code  class="highlight"><span></span><span class="o">*</span></code>
 should be <code  class="highlight"><span></span><span class="n">i</span></code>
, since the evidence could come from anywhere in the sentence. In an autoregressive setting, by contrast, the model only has to use <code  class="highlight"><span></span><span class="n">h</span></code>
 to predict <code  class="highlight"><span></span><span class="n">o</span></code>
, <code  class="highlight"><span></span><span class="n">hom</span></code>
 to predict <code  class="highlight"><span></span><span class="n">e</span></code>
, and so on. So the model starts by predicting with a shorter context and learns to use longer contexts later. Because MLMs lack this built-in progression, training is a lot faster if you start with a smaller sequence length and then switch to a longer sequence length later.</p>
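<p>One way to express the idea of starting with short sequences is a step-based schedule. The sketch below is only an illustration: the function name <code>seq_len_for_step</code> and all step counts and lengths in <code>schedule</code> are made-up values, not settings from this implementation or the paper.</p>

```python
def seq_len_for_step(step: int,
                     schedule=((0, 32), (10_000, 128), (50_000, 512))) -> int:
    """Return the sequence length to use at a training step.

    `schedule` holds (first_step, seq_len) pairs sorted by first_step.
    The numbers are hypothetical and chosen only for illustration.
    """
    current = schedule[0][1]
    for first_step, seq_len in schedule:
        if step >= first_step:
            current = seq_len  # latest stage whose start has been reached
    return current


lengths = [seq_len_for_step(s) for s in (0, 9_999, 10_000, 60_000)]
```

<p>A training loop could call such a function before sampling each batch and crop or re-chunk the data to the returned length.</p>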
<p>Here is <a href="experiment.html">the training code</a> for a simple MLM model.</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">65</span><span></span><span class="kn">from</span> <span class="nn">typing</span> <span class="kn">import</span> <span class="n">List</span>
<span class="lineno">66</span>
<span class="lineno">67</span><span class="kn">import</span> <span class="nn">torch</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-1'>
        <div class='docs doc-strings'>
            <div class='section-link'>
                <a href='#section-1'>#</a>
            </div>
            <h2>Masked LM (MLM)</h2>
<p>This class implements the masking procedure for a given batch of token sequences.</p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">70</span><span class="k">class</span> <span class="nc">MLM</span><span class="p">:</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-2'>
        <div class='docs doc-strings'>
            <div class='section-link'>
                <a href='#section-2'>#</a>
            </div>
            <ul><li><code  class="highlight"><span></span><span class="n">padding_token</span></code>
 is the padding token <code  class="highlight"><span></span><span class="p">[</span><span class="n">PAD</span><span class="p">]</span></code>
.  We will use this to mark the labels that shouldn&#x27;t be used for loss calculation. </li>
<li><code  class="highlight"><span></span><span class="n">mask_token</span></code>
 is the masking token <code  class="highlight"><span></span><span class="p">[</span><span class="n">MASK</span><span class="p">]</span></code>
. </li>
<li><code  class="highlight"><span></span><span class="n">no_mask_tokens</span></code>
 is a list of tokens that should not be masked. This is useful if we are training the MLM with another task like classification at the same time, and we have tokens such as <code  class="highlight"><span></span><span class="p">[</span><span class="n">CLS</span><span class="p">]</span></code>
 that shouldn&#x27;t be masked. </li>
<li><code  class="highlight"><span></span><span class="n">n_tokens</span></code>
 is the total number of tokens (used for generating random tokens) </li>
<li><code  class="highlight"><span></span><span class="n">masking_prob</span></code>
 is the masking probability </li>
<li><code  class="highlight"><span></span><span class="n">randomize_prob</span></code>
 is the probability of replacing with a random token </li>
<li><code  class="highlight"><span></span><span class="n">no_change_prob</span></code>
 is the probability of replacing with the original token</li></ul>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">77</span>    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="o">*</span><span class="p">,</span>
<span class="lineno">78</span>                 <span class="n">padding_token</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">mask_token</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">no_mask_tokens</span><span class="p">:</span> <span class="n">List</span><span class="p">[</span><span class="nb">int</span><span class="p">],</span> <span class="n">n_tokens</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span>
<span class="lineno">79</span>                 <span class="n">masking_prob</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.15</span><span class="p">,</span> <span class="n">randomize_prob</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.1</span><span class="p">,</span> <span class="n">no_change_prob</span><span class="p">:</span> <span class="nb">float</span> <span class="o">=</span> <span class="mf">0.1</span><span class="p">,</span>
<span class="lineno">80</span>                 <span class="p">):</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-3'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-3'>#</a>
            </div>
            
        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">93</span>        <span class="bp">self</span><span class="o">.</span><span class="n">n_tokens</span> <span class="o">=</span> <span class="n">n_tokens</span>
<span class="lineno">94</span>        <span class="bp">self</span><span class="o">.</span><span class="n">no_change_prob</span> <span class="o">=</span> <span class="n">no_change_prob</span>
<span class="lineno">95</span>        <span class="bp">self</span><span class="o">.</span><span class="n">randomize_prob</span> <span class="o">=</span> <span class="n">randomize_prob</span>
<span class="lineno">96</span>        <span class="bp">self</span><span class="o">.</span><span class="n">masking_prob</span> <span class="o">=</span> <span class="n">masking_prob</span>
<span class="lineno">97</span>        <span class="bp">self</span><span class="o">.</span><span class="n">no_mask_tokens</span> <span class="o">=</span> <span class="n">no_mask_tokens</span> <span class="o">+</span> <span class="p">[</span><span class="n">padding_token</span><span class="p">,</span> <span class="n">mask_token</span><span class="p">]</span>
<span class="lineno">98</span>        <span class="bp">self</span><span class="o">.</span><span class="n">padding_token</span> <span class="o">=</span> <span class="n">padding_token</span>
<span class="lineno">99</span>        <span class="bp">self</span><span class="o">.</span><span class="n">mask_token</span> <span class="o">=</span> <span class="n">mask_token</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-4'>
        <div class='docs doc-strings'>
            <div class='section-link'>
                <a href='#section-4'>#</a>
            </div>
            <ul><li><code  class="highlight"><span></span><span class="n">x</span></code>
 is the batch of input token sequences.  It&#x27;s a tensor of type <code  class="highlight"><span></span><span class="n">long</span></code>
 with shape <code  class="highlight"><span></span><span class="p">[</span><span class="n">seq_len</span><span class="p">,</span> <span class="n">batch_size</span><span class="p">]</span></code>
.</li></ul>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">101</span>    <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">x</span><span class="p">:</span> <span class="n">torch</span><span class="o">.</span><span class="n">Tensor</span><span class="p">):</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-5'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-5'>#</a>
            </div>
            <p>Mask <code  class="highlight"><span></span><span class="n">masking_prob</span></code>
 of tokens </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">108</span>        <span class="n">full_mask</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">x</span><span class="o">.</span><span class="n">device</span><span class="p">)</span> <span class="o">&lt;</span> <span class="bp">self</span><span class="o">.</span><span class="n">masking_prob</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-6'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-6'>#</a>
            </div>
            <p>Unmask <code  class="highlight"><span></span><span class="n">no_mask_tokens</span></code>
 </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">110</span>        <span class="k">for</span> <span class="n">t</span> <span class="ow">in</span> <span class="bp">self</span><span class="o">.</span><span class="n">no_mask_tokens</span><span class="p">:</span>
<span class="lineno">111</span>            <span class="n">full_mask</span> <span class="o">&amp;=</span> <span class="n">x</span> <span class="o">!=</span> <span class="n">t</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-7'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-7'>#</a>
            </div>
            <p>A mask for tokens to be replaced with original tokens </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">114</span>        <span class="n">unchanged</span> <span class="o">=</span> <span class="n">full_mask</span> <span class="o">&amp;</span> <span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">x</span><span class="o">.</span><span class="n">device</span><span class="p">)</span> <span class="o">&lt;</span> <span class="bp">self</span><span class="o">.</span><span class="n">no_change_prob</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-8'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-8'>#</a>
            </div>
            <p>A mask for tokens to be replaced with a random token </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">116</span>        <span class="n">random_token_mask</span> <span class="o">=</span> <span class="n">full_mask</span> <span class="o">&amp;</span> <span class="p">(</span><span class="n">torch</span><span class="o">.</span><span class="n">rand</span><span class="p">(</span><span class="n">x</span><span class="o">.</span><span class="n">shape</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="n">x</span><span class="o">.</span><span class="n">device</span><span class="p">)</span> <span class="o">&lt;</span> <span class="bp">self</span><span class="o">.</span><span class="n">randomize_prob</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-9'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-9'>#</a>
            </div>
            <p>Indexes of tokens to be replaced with random tokens </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">118</span>        <span class="n">random_token_idx</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">nonzero</span><span class="p">(</span><span class="n">random_token_mask</span><span class="p">,</span> <span class="n">as_tuple</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-10'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-10'>#</a>
            </div>
            <p>Random tokens for each of the locations </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">120</span>        <span class="n">random_tokens</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">randint</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">n_tokens</span><span class="p">,</span> <span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="n">random_token_idx</span><span class="p">[</span><span class="mi">0</span><span class="p">]),),</span> <span class="n">device</span><span class="o">=</span><span class="n">x</span><span class="o">.</span><span class="n">device</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-11'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-11'>#</a>
            </div>
            <p>The final set of tokens that are going to be replaced by <code  class="highlight"><span></span><span class="p">[</span><span class="n">MASK</span><span class="p">]</span></code>
 </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">122</span>        <span class="n">mask</span> <span class="o">=</span> <span class="n">full_mask</span> <span class="o">&amp;</span> <span class="o">~</span><span class="n">random_token_mask</span> <span class="o">&amp;</span> <span class="o">~</span><span class="n">unchanged</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-12'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-12'>#</a>
            </div>
            <p>Make a clone of the input for the labels </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">125</span>        <span class="n">y</span> <span class="o">=</span> <span class="n">x</span><span class="o">.</span><span class="n">clone</span><span class="p">()</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-13'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-13'>#</a>
            </div>
            <p>Replace with <code  class="highlight"><span></span><span class="p">[</span><span class="n">MASK</span><span class="p">]</span></code>
 tokens; note that this doesn&#x27;t include the tokens that will have the original token unchanged and those that are replaced with a random token. </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">130</span>        <span class="n">x</span><span class="o">.</span><span class="n">masked_fill_</span><span class="p">(</span><span class="n">mask</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">mask_token</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-14'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-14'>#</a>
            </div>
            <p>Assign random tokens </p>
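<p>Because the <code>unchanged</code> and <code>random_token_mask</code> masks are drawn independently, and the random-token assignment is applied after the <code>[MASK]</code> fill, the effective split among selected tokens differs slightly from an exact 80/10/10. The following sketch works out the expected fractions; the helper name <code>effective_split</code> is an assumption for illustration.</p>

```python
def effective_split(randomize_prob: float = 0.1, no_change_prob: float = 0.1):
    """Expected fractions of selected tokens that end up random, unchanged,
    or [MASK] when the two masks are drawn independently and the
    random-token assignment wins where both masks are set."""
    random_frac = randomize_prob
    unchanged_frac = no_change_prob * (1 - randomize_prob)
    mask_frac = (1 - randomize_prob) * (1 - no_change_prob)
    return random_frac, unchanged_frac, mask_frac


random_frac, unchanged_frac, mask_frac = effective_split()
print(f"random={random_frac:.2f} unchanged={unchanged_frac:.2f} mask={mask_frac:.2f}")
```

<p>With the default probabilities this gives about 10% random tokens, 9% unchanged tokens and 81% <code>[MASK]</code> tokens among the selected positions.</p>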

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">132</span>        <span class="n">x</span><span class="p">[</span><span class="n">random_token_idx</span><span class="p">]</span> <span class="o">=</span> <span class="n">random_tokens</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-15'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-15'>#</a>
            </div>
            <p>Assign token <code  class="highlight"><span></span><span class="p">[</span><span class="n">PAD</span><span class="p">]</span></code>
 to all the other locations in the labels. The labels equal to <code  class="highlight"><span></span><span class="p">[</span><span class="n">PAD</span><span class="p">]</span></code>
 will not be used in the loss. </p>
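<p>To illustrate how labels equal to <code>[PAD]</code> drop out of the loss, here is a plain-Python sketch of a cross entropy that skips those positions. The function name <code>masked_cross_entropy</code> and the toy values are assumptions for the example; in PyTorch the same effect is typically obtained with the <code>ignore_index</code> argument of <code>torch.nn.CrossEntropyLoss</code>.</p>

```python
import math
from typing import List


def masked_cross_entropy(logits: List[List[float]], labels: List[int],
                         padding_token: int) -> float:
    """Mean cross entropy over positions whose label is not `padding_token`."""
    total, count = 0.0, 0
    for row, label in zip(logits, labels):
        if label == padding_token:
            continue  # this position was not selected for prediction
        # Negative log-softmax of the correct class, computed stably
        m = max(row)
        log_z = m + math.log(sum(math.exp(v - m) for v in row))
        total += log_z - row[label]
        count += 1
    # This sketch returns 0.0 when no position has a label
    return total / count if count else 0.0


PAD = 0  # hypothetical padding token id
logits = [[0.0, 2.0, 0.0], [5.0, 0.0, 0.0], [0.0, 0.0, 3.0]]
labels = [1, PAD, 2]
loss = masked_cross_entropy(logits, labels, PAD)
```

<p>Changing the logits at the second position leaves <code>loss</code> unchanged, because its label is <code>PAD</code>.</p>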

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">136</span>        <span class="n">y</span><span class="o">.</span><span class="n">masked_fill_</span><span class="p">(</span><span class="o">~</span><span class="n">full_mask</span><span class="p">,</span> <span class="bp">self</span><span class="o">.</span><span class="n">padding_token</span><span class="p">)</span></pre></div>
        </div>
    </div>
    <div class='section' id='section-16'>
        <div class='docs'>
            <div class='section-link'>
                <a href='#section-16'>#</a>
            </div>
            <p>Return the masked input and the labels </p>

        </div>
        <div class='code'>
            <div class="highlight"><pre><span class="lineno">139</span>        <span class="k">return</span> <span class="n">x</span><span class="p">,</span> <span class="n">y</span></pre></div>
        </div>
    </div>
    <div class='footer'>
        <a href="https://labml.ai">labml.ai</a>
    </div>
</div>
<script src="../../interactive.js?v=1"></script>
<script>
    function handleImages() {
        var images = document.querySelectorAll('p>img')

        for (var i = 0; i < images.length; ++i) {
            handleImage(images[i])
        }
    }

    function handleImage(img) {
        img.parentElement.style.textAlign = 'center'

        var modal = document.createElement('div')
        modal.id = 'modal'

        var modalContent = document.createElement('div')
        modal.appendChild(modalContent)

        var modalImage = document.createElement('img')
        modalContent.appendChild(modalImage)

        var span = document.createElement('span')
        span.classList.add('close')
        span.textContent = 'x'
        modal.appendChild(span)

        img.onclick = function () {
            console.log('clicked')
            document.body.appendChild(modal)
            modalImage.src = img.src
        }

        span.onclick = function () {
            document.body.removeChild(modal)
        }
    }

    handleImages()
</script>
</body>
</html>